I am using PPO in the context of multi-agent RL. I was wondering whether PyTorch has a built-in way to handle the steps at which recurrent hidden states should be reinitialized to zeros.
What I have found is this implementation:
def forward(self, x, hxs, masks):
    """Run ``self.rnn`` over ``x``, zeroing hidden states where ``masks`` is 0.

    Two input layouts are handled, selected by comparing the leading
    dimensions of ``x`` and ``hxs``:

    * Single step: ``x.size(0) == hxs.size(0)``. ``x`` is given a leading
      time axis of length 1 and ``hxs`` is multiplied by ``masks`` before
      the call.
    * Flattened sequence: ``x`` is a ``(T, N, -1)`` tensor flattened to
      ``(T * N, -1)`` with ``N = hxs.size(0)``. The sequence is split at
      every step ``t >= 1`` where any entry of ``masks[t]`` is zero, so
      each chunk without such a step is processed in one ``self.rnn``
      call. The incoming hidden state of each chunk is multiplied by the
      mask row of the chunk's first step.

    Args:
        x: Input features, ``(N, D)`` or flattened ``(T * N, D)``.
        hxs: Hidden states with ``N`` as leading dimension; presumably
            ``(N, self._recurrent_N, H)`` -- TODO confirm against caller.
        masks: Multiplicative mask; a zero entry resets the matching
            hidden state. Assumed ``(N, 1)`` in the single-step case (the
            ``repeat(1, ...)`` call relies on it) and ``T * N`` elements
            in the sequence case.

    Returns:
        Tuple ``(x, hxs)``: the RNN outputs, flattened back to
        ``(T * N, -1)`` in the sequence case, and the final hidden states
        transposed back to the layout of the ``hxs`` argument.

    NOTE(review): the pasted snippet ended without a ``return``, so the
    computed tensors were discarded. Confirm against the original
    implementation whether any post-processing precedes its return.
    """
    if x.size(0) == hxs.size(0):
        # Single-step case: mask the hidden state, then move the second
        # axis of hxs to the front before handing it to the RNN.
        gate = masks.repeat(1, self._recurrent_N).unsqueeze(-1)
        x, hxs = self.rnn(
            x.unsqueeze(0),
            (hxs * gate).transpose(0, 1).contiguous(),
        )
        return x.squeeze(0), hxs.transpose(0, 1)

    # x is a (T, N, -1) tensor that has been flattened to (T * N, -1).
    N = hxs.size(0)
    T = x.size(0) // N  # integer division; no float round trip

    # Unflatten the inputs; same deal with the masks.
    x = x.view(T, N, x.size(1))
    masks = masks.view(T, N)

    # Steps (other than t=0) where any agent has a zero mask. t=0 is always
    # treated as a boundary because that keeps the logic uniform.
    # squeeze(-1) keeps the result 1-D even when a single step matches.
    reset_steps = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze(-1)
    # +1 corrects for the masks[1:] offset; add t=0 and t=T as sentinels.
    boundaries = [0] + (reset_steps + 1).tolist() + [T]

    hxs = hxs.transpose(0, 1)

    outputs = []
    for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
        # Steps without any zero in masks are processed together, which is
        # much faster than stepping one at a time. Broadcasting applies the
        # per-agent mask across the leading axis of hxs.
        temp = (hxs * masks[start_idx].view(1, -1, 1)).contiguous()
        rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp)
        outputs.append(rnn_scores)

    # Concatenate to a (T, N, -1) tensor, then flatten back to (T * N, -1).
    x = torch.cat(outputs, dim=0).reshape(T * N, -1)
    return x, hxs.transpose(0, 1)
submitted by /u/No_Possibility_7588
[link] [comments]
(2 min)